import copy
import datetime
import enum
import json
import os
import weakref
from pathlib import Path
from queue import Queue
from typing import Dict, List, Optional, Tuple, Union

import psutil
import torch

from tensorrt_llm.logger import logger

from .._torch.pyexecutor.llm_request import LlmResponse
from .._utils import (global_mpi_rank, global_mpi_size, mpi_comm, mpi_rank,
                      nvtx_range_debug)
from ..bindings import executor as tllm
from ..builder import ConfigEncoder, Engine, EngineConfig
from ..llmapi.llm_args import BaseLlmArgs, PybindMirror
from ..llmapi.tokenizer import TokenizerBase
from ..llmapi.tracer import global_tracer
from ..llmapi.utils import _SyncQueue, get_numa_aware_cpu_affinity, logger_debug
from ..lora_manager import LoraManager
from ..metrics import RequestEventTiming
from ..prompt_adapter_manager import PromptAdapterManager
from ..runtime import ModelConfig
from ..runtime.model_runner import _engine_config_to_model_config
from ..sampling_params import BatchedLogitsProcessor, SamplingParams
from .executor import GenerationExecutor, IterationResultQueue
from .ipc import FusedIpcQueue, IpcQueue
from .postproc_worker import (PostprocParams, PostprocWorker,
                              PostprocWorkerConfig)
from .request import GenerationRequest, LoRARequest, PromptAdapterRequest
from .result import (GenerationResult, LogProbsResult, ResponseWrapper,
                     compute_logprobs)
from .utils import (ErrorResponse, IntraProcessQueue, RequestError,
                    is_llm_response)

# Names exported by a star-import of this module. `_init_hf_modules` is listed
# explicitly because names with a leading underscore are otherwise not exported.
__all__ = [
    "BaseWorker",
    "_init_hf_modules",
]


def _init_hf_modules():
    """Set up the cached HuggingFace modules needed by trust_remote_code=True models.

    Repeated calls are harmless (idempotent). It is meant to be invoked:
    1. When this module is imported (main process and spawned subprocesses)
    2. When worker_main starts (forked processes or external MPI ranks)

    Nothing is raised to the caller: an ImportError is logged as a warning and
    any other exception is logged as an error.

    References: https://github.com/vllm-project/vllm/pull/871
    """
    try:
        # Imported lazily so a missing `transformers` only produces a warning.
        from transformers.dynamic_module_utils import \
            init_hf_modules as _register_hf_modules
        _register_hf_modules()
        logger.debug("HF modules initialized")
    except ImportError as import_err:
        logger.warning(f"ImportError initializing HF modules: {import_err}")
    except Exception as err:
        logger.error(f"Exception initializing HF modules: {err}")


# Import-time invocation: covers the main process and spawned subprocesses.
_init_hf_modules()


class BaseWorker(GenerationExecutor):

    class WorkerExit(GeneratorExit):
        """Exception type that `BaseWorker.__exit__` suppresses, like a clean exit."""
        pass

    def __init__(
        self,
        engine: Union[Path, Engine],
        executor_config: Optional[tllm.ExecutorConfig] = None,
        batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
        postproc_worker_config: Optional[PostprocWorkerConfig] = None,
        is_llm_executor: Optional[bool] = None,
        hf_model_dir: Optional[Path] = None,
        tokenizer: Optional[TokenizerBase] = None,
        llm_args: Optional[BaseLlmArgs] = None,
    ) -> None:
        """Record the construction inputs and initialize bookkeeping state.

        The backend engine itself is not created here (`self.engine` starts as
        None); see `setup_engine`.

        Args:
            engine: Engine input stored for later use by `setup_engine`.
            executor_config: Executor configuration stored as-is.
            batched_logits_processor: Batched logits processor stored as-is.
            postproc_worker_config: Post-processing settings. When falsy, a
                default `PostprocWorkerConfig()` supplies the values passed to
                the parent constructor; the attribute keeps the original value.
            is_llm_executor: Forwarded to the parent constructor and stored.
            hf_model_dir: Stored as-is.
            tokenizer: Stored as-is.
            llm_args: LLM arguments; when given, `backend` is read from it.
        """
        # Fall back to defaults only for the values handed to the parent class.
        effective_postproc = (postproc_worker_config if postproc_worker_config
                              else PostprocWorkerConfig())
        super().__init__(
            num_postprocess_workers=effective_postproc.num_postprocess_workers,
            postprocess_tokenizer_dir=effective_postproc.
            postprocess_tokenizer_dir,
            is_llm_executor=is_llm_executor,
        )

        # Raw constructor inputs.
        self._engine = engine
        self._executor_config = executor_config
        self._batched_logits_processor = batched_logits_processor
        self._postproc_worker_config = postproc_worker_config
        self._is_llm_executor = is_llm_executor
        self._hf_model_dir = hf_model_dir
        self._tokenizer = tokenizer
        self.llm_args = llm_args

        # Runtime state filled in later.
        self.engine = None
        self.result_queue: Optional[IpcQueue] = None
        self.postproc_queues: Optional[List[IpcQueue]] = None
        self.rank = mpi_rank()
        self.global_rank = global_mpi_rank()

        # client_id -> GenerationResult
        self._results: Dict[int, GenerationResult] = {}
        # client_id (from Proxy) -> request_id returned by the runtime backend
        self._client_id_to_request_id: Dict[int, int] = {}

        # The helper receives a weak proxy rather than a strong reference.
        self._await_response_helper = AwaitResponseHelper(weakref.proxy(self))

        self._backend = llm_args.backend if llm_args is not None else None
        self._is_pytorch_backend = self._backend in ("pytorch", "_autodeploy")
        if self._is_pytorch_backend:
            self._lora_config = llm_args.lora_config
        else:
            self._lora_config = None

        if global_mpi_size() > 1:
            logger.set_rank(self.global_rank)

    def _configure_affinity(self, device_id):
        '''Inspect the worker's CPU affinity and adjust it using NUMA topology.

        Args:
            device_id: The CUDA device ID used to pick the NUMA-aware CPU set.

        Note:
            A warning is logged whenever the process starts with a constrained
            affinity. The environment variable controls the rest:
                TLLM_NUMA_AWARE_WORKER_AFFINITY = <unset>
                    -> Unconstrained affinity is auto-configured; an externally
                       constrained affinity is reset to all logical CPUs.
                TLLM_NUMA_AWARE_WORKER_AFFINITY = 1
                    -> Affinity is always auto-configured.
                TLLM_NUMA_AWARE_WORKER_AFFINITY = 0 or any other value
                    -> Affinity is never auto-configured.
        '''
        worker_pid = os.getpid()
        proc = psutil.Process(worker_pid)
        current_cpus = proc.cpu_affinity()
        every_cpu = list(range(psutil.cpu_count()))

        is_constrained = current_cpus != every_cpu
        env_setting = os.environ.get("TLLM_NUMA_AWARE_WORKER_AFFINITY")
        env_unset = env_setting is None

        if is_constrained:
            logger.warning(
                f"Worker process {worker_pid} is affined to run on the following CPUs: "
                f"{current_cpus} (subset of all logical CPUs). This may harm "
                f"performance if set incorrectly.")
            if env_unset:
                # No explicit user request: drop the inherited constraints.
                logger.warning(
                    f"Worker process {worker_pid} has constrained CPU affinity "
                    f"but `TLLM_NUMA_AWARE_WORKER_AFFINITY` is not set. "
                    f"Removing CPU affinity constraints.")
                proc.cpu_affinity(every_cpu)

        # Auto-configure when explicitly requested, or when nothing was
        # requested and the process started out unconstrained.
        auto_configure = (env_setting == "1"
                          or (env_unset and not is_constrained))
        if not auto_configure:
            return

        proc.cpu_affinity(get_numa_aware_cpu_affinity(device_id))
        logger.info(
            f"Worker process {worker_pid} CPU affinity set to "
            f"{proc.cpu_affinity()} for optimal NUMA-aware scheduling.")

    def _get_comm_ranks_device_id(self):
        """Select this rank's CUDA device and gather ranks/devices over MPI.

        The device is `global_rank % torch.cuda.device_count()`. CPU affinity
        is configured for that device before returning.

        Returns:
            A `(comm_ranks, device_ids)` pair holding the all-gathered global
            ranks and device ids.
        """
        local_device = self.global_rank % torch.cuda.device_count()
        torch.cuda.set_device(local_device)

        # Make sure C++ executor would use same devices/ranks as py_executor
        my_global_rank = global_mpi_rank()
        gathered_ranks = mpi_comm().allgather(my_global_rank)
        gathered_devices = mpi_comm().allgather(local_device)

        self._configure_affinity(local_device)

        return gathered_ranks, gathered_devices

    def setup_engine(self):
        """Create the backend engine and the adapter managers that depend on it.

        Two construction paths exist:
          * `llm_args` is set: a Python executor is created for the "pytorch"
            or "_autodeploy" backend; any other backend raises `ValueError`.
          * `llm_args` is None: a C++ `tllm.Executor` is created from
            `self._engine`, which is either an in-memory `Engine` or a value
            passed straight to `tllm.Executor` (e.g. an engine directory).

        Once the engine exists, rank 0 on the C++ path loads the engine config
        to derive the runtime model config and, when the build enables them, a
        `LoraManager` and/or a `PromptAdapterManager`. For the pytorch backend
        with a LoRA config, the LoRA manager comes from the engine's PEFT cache
        manager instead.

        Raises:
            ValueError: `llm_args.backend` is not a supported backend.
        """

        # A list of engines holds one entry per rank; keep only ours.
        if isinstance(self._engine, list):
            self._engine = self._engine[self.rank]

        def _create_py_executor():
            """Build the Python executor for the pytorch/_autodeploy backend."""
            args = {}
            assert hasattr(
                self.llm_args, "backend"
            ), "llm_args should be with backend in _create_py_executor"
            # Called for its side effects (device selection, affinity, MPI
            # gathers); the returned ranks/devices are not needed here.
            _ = self._get_comm_ranks_device_id()
            if self._backend == "pytorch":
                from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
                    create_py_executor
                create_executor = create_py_executor
                args["llm_args"] = self.llm_args
                args["checkpoint_dir"] = self._hf_model_dir
                args["tokenizer"] = self._tokenizer
            elif self._backend == "_autodeploy":
                from tensorrt_llm._torch.auto_deploy.llm_args import \
                    LlmArgs as ADLlmArgs
                from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                    create_autodeploy_executor
                create_executor = create_autodeploy_executor
                assert isinstance(self.llm_args, ADLlmArgs)
                args["ad_config"] = self.llm_args
                args["tokenizer"] = self._tokenizer
            else:
                raise ValueError(f"Unsupported backend config: {self._backend}")

            # Define additional attributes that can be used later, such as in _deduce_max_tokens
            self.mapping = self.llm_args.parallel_config.to_mapping()
            self.checkpoint_loader = None
            if self._backend == "pytorch":
                from tensorrt_llm._torch.pyexecutor.model_loader import \
                    _construct_checkpoint_loader
                self.checkpoint_loader = _construct_checkpoint_loader(
                    self.llm_args.backend, self.llm_args.checkpoint_loader,
                    self.llm_args.checkpoint_format)

            self.max_seq_len = self.llm_args.max_seq_len
            # create_py_executor may change some fields of llm_args
            _executor = create_executor(**args)
            if _executor.max_seq_len is not None:
                # max_seq_len might be updated by model engine as in create_py_executor
                self.max_seq_len = _executor.max_seq_len
            return _executor

        def _create_engine(executor_config):
            """Build the C++ `tllm.Executor` from `self._engine`."""
            engine = self._engine
            if executor_config is None:
                executor_config = tllm.ExecutorConfig(1)
            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
                processor_batched=self._batched_logits_processor,
                replicate=False)
            comm_ranks, device_ids = self._get_comm_ranks_device_id()
            executor_config.parallel_config = tllm.ParallelConfig(
                participant_ids=comm_ranks, device_ids=device_ids)

            if isinstance(engine, Engine):
                return tllm.Executor(engine.engine,
                                     json.dumps(engine.config.to_dict(),
                                                cls=ConfigEncoder),
                                     tllm.ModelType.DECODER_ONLY,
                                     executor_config=executor_config,
                                     managed_weights=engine.managed_weights)

            assert not hasattr(executor_config, "backend")
            return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
                                 executor_config)

        self.engine = _create_py_executor(
        ) if self.llm_args is not None else _create_engine(
            self._executor_config)

        self._lora_manager: Optional[LoraManager] = None
        self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
        self._runtime_model_config: Optional[ModelConfig] = None
        if self.rank == 0 and isinstance(self.engine, tllm.Executor):
            # `self.engine` is the tllm.Executor in this branch, so the
            # in-memory-engine check has to look at the original input
            # `self._engine`; otherwise an `Engine` object would be formatted
            # into a bogus "<repr>/config.json" path below.
            if isinstance(self._engine, Engine):
                engine_config = self._engine.config
            else:
                engine_config = EngineConfig.from_json_file(
                    f"{self._engine}/config.json")
            self._runtime_model_config = _engine_config_to_model_config(
                engine_config)
            if engine_config.build_config.plugin_config.lora_plugin:
                # TODO(azuker): Passing peft cache manager to LoraManager is used for LoRA optimization
                # (see LoraManager constructor docstring). Getting the peft cache manager from this
                # point in the TRT flow is currently not supported (it's at the CPP
                # Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA
                # optimization is not available in TRT-python flow.
                self._lora_manager = LoraManager(
                    mapping=engine_config.pretrained_config.mapping,
                    model_config=self._runtime_model_config,
                    cpp_peft_cache_manager=None)
            if engine_config.build_config.max_prompt_embedding_table_size > 0:
                self._prompt_adapter_manager = PromptAdapterManager()

        if self._backend == "pytorch" and self._lora_config is not None:
            from tensorrt_llm._torch.pyexecutor.resource_manager import \
                ResourceManagerType
            peft_cache_manager = self.engine.resource_manager.resource_managers.get(
                ResourceManagerType.PEFT_CACHE_MANAGER)
            self._lora_manager = peft_cache_manager.get_lora_manager()
            lora_model_config = self.engine.model_engine.lora_model_config
            assert lora_model_config is not None
            self._lora_model_config = lora_model_config

    def await_responses(self, timeout: Optional[float] = None) -> list:
        """Forward to the engine's `await_responses`.

        Args:
            timeout: Seconds to wait. It is converted to a
                `datetime.timedelta`; None is passed through unchanged.

        Returns:
            Whatever the engine's `await_responses` returns.
        """
        wait_for = None
        if timeout is not None:
            wait_for = datetime.timedelta(seconds=timeout)
        return self.engine.await_responses(timeout=wait_for)

    def fetch_stats(self) -> list:
        """Return the engine's latest iteration stats.

        For a `tllm.Executor` engine each iteration stat is paired with None
        as `(iter_stat, None)`; for any other engine the engine's result is
        returned unchanged.
        """
        pair_with_none = isinstance(self.engine, tllm.Executor)
        latest = self.engine.get_latest_iteration_stats()
        if not pair_with_none:
            return latest
        #TODO: Support req stats with TRT engine
        #      This would require ensuring iter and req stats have same size
        return [(entry, None) for entry in latest]

    def fetch_kv_cache_events(self) -> list:
        if isinstance(self.engine, tllm.Executor):
            return self.engine.get_latest_kv_cache_events()
        else:
            return self.engine.get_latest_kv_cache_events()

    def set_result_queue(self, queue):
        """In multi-gpu mode, result_queue will be set here to communicate between the proxy and the worker 0 process."""
        assert self.postproc_queues is None
        self.result_queue = queue

    def set_postproc_queues(self, queues: List["IpcQueue"]):
        """ Set the IPC queues for feeding post-processing processes. """
        assert self.result_queue is None
        self.postproc_queues = queues

    def _set_iteration_result_queue(self, it_result_queue: IterationResultQueue,
                                    queue: Union[Queue, FusedIpcQueue,
                                                 IntraProcessQueue]):
        assert not it_result_queue.is_initialized, "Iteration result queue should not already be initialized."
        it_result_queue.is_initialized = True
        it_result_queue.queue = queue
        it_result_queue.aqueue = None

    def return_queue(self, client_id: int):
        """ If a centralized result queue is registered (used for communication with the proxy)
            send the message there.
            Otherwise, push the result directly in the GenerationResult queue.
        """
        if self.result_queue is not None:
            return self.result_queue
        return self._results[client_id].queue

    def abort_request(self, client_id: int) -> None:
        # NOTE: the request_id is the request_id generated by cpp runtime, not the client_id
        if self.engine.can_enqueue_requests():
            request_id = self._client_id_to_request_id.get(client_id, None)
            if request_id is None:
                logger.warning(
                    f"Request of client_id {client_id} is finished, cannot abort it."
                )
                return
            self.engine.cancel_request(request_id)

    def _engine_response_callback(self, response: tllm.Response):
        return response

    def _has_background_error(self) -> bool:
        return not self._error_queue.empty()

    def _create_error_response(self, response: tllm.Response) -> ErrorResponse:
        """Turn the next queued background error into an `ErrorResponse`.

        The error is taken from the error queue without blocking and is
        asserted to be an `Exception`. The client and request ids are copied
        from `response`; the error text is `str()` of the exception.
        """
        pending_error = self._error_queue.get_nowait()
        assert isinstance(pending_error, Exception)
        error_text = str(pending_error)
        return ErrorResponse(response.client_id, error_text,
                             response.request_id)

    def start(self):
        raise NotImplementedError(
            "start method is not implemented in BaseWorker")

    def _load_lora_adapter(self, lora_request: LoRARequest) -> bool:
        """Returns True if the adapter was loaded by this call, False if it was already loaded"""
        adapter_id = str(lora_request.adapter_id)
        newly_loaded_uids = self._lora_manager.load_from_ckpt(
            [lora_request.path],
            model_config=self._runtime_model_config if
            self._runtime_model_config is not None else self._lora_model_config,
            uids=[adapter_id],
            ckpt_source=lora_request.ckpt_source)
        return adapter_id in newly_loaded_uids

    def _load_prompt_adapter(self,
                             prompt_adapter_request: PromptAdapterRequest):
        self._prompt_adapter_manager.load_from_ckpt(
            [prompt_adapter_request.local_path],
            model_config=self._runtime_model_config,
            uids=[str(prompt_adapter_request.adapter_id)])

    def _enqueue_request(self,
                         request: GenerationRequest,
                         result_wait_queue=None) -> int:
        """Translate a `GenerationRequest` into a `tllm.Request` and enqueue it.

        Args:
            request: The request to enqueue; its `id` must already be set. When
                it carries multimodal content, its
                `multimodal_params.multimodal_input` is reset to None here.
            result_wait_queue: Forwarded to the engine's `enqueue_request` only
                for the pytorch/_autodeploy backends and only when not None.

        Returns:
            The request id returned by the engine's `enqueue_request`.

        Raises:
            RequestError: Wraps any exception raised while building the
                `tllm.Request` (including `max_tokens` deduction) or while
                enqueueing it.
            ValueError: A context-only request is submitted to a pytorch
                backend engine that has a KV cache transceiver while the
                overlap scheduler is enabled.

        Note:
            Failures before the `try` block (adapter loading, the
            disaggregated-serving assertion, the ValueError above) are not
            wrapped in `RequestError`.
        """
        assert request.id is not None
        py_lora_path = None
        # LoRA: load the adapter and build the C++ LoRA config for this request.
        if self._lora_manager is not None and request.lora_request is not None:
            # Cache state is sampled before loading; the weights are attached
            # below only when the adapter was not in the CPU cache at this point.
            adapter_in_cache = self._lora_manager.is_adapter_in_cpu_cache(
                request.lora_request.adapter_id)
            self._load_lora_adapter(request.lora_request)
            uid = str(request.lora_request.adapter_id)
            lora_config = tllm.LoraConfig(
                task_id=request.lora_request.adapter_id,
                weights=self._lora_manager.cpp_lora_weights[uid]
                if not adapter_in_cache else None,
                config=self._lora_manager.cpp_lora_config[uid])
            py_lora_path = request.lora_request.lora_path
        else:
            lora_config = None

        # Work on a copy so prepending virtual tokens below does not modify
        # the token list held by the request.
        prompt_token_ids = copy.deepcopy(request.prompt_token_ids)
        prompt_tuning_config = None
        if request.prompt_adapter_request is not None:
            self._load_prompt_adapter(request.prompt_adapter_request)
            uid = str(request.prompt_adapter_request.adapter_id)
            prompt_tuning_config = tllm.PromptTuningConfig(
                self._prompt_adapter_manager.uid_to_weights[uid])
            vocab_size = self._runtime_model_config.vocab_size
            pa_length = prompt_tuning_config.embedding_table.size(0)
            # Prepend one id per embedding-table row, numbered from vocab_size
            # upwards, ahead of the real prompt tokens.
            prompt_token_ids = list(range(
                vocab_size, vocab_size + pa_length)) + prompt_token_ids

        # MULTIMODAL
        # NOTE: Since, we only support PyTorch backend for multimodal, we will send multimodal_data through the 'py_multimodal_data' field
        # except `multimodal_input` as it needs to go through the C++ runtime.
        multimodal_input = None
        if request.multimodal_params is not None and request.multimodal_params.has_content(
        ):
            if request.multimodal_params.multimodal_input is not None:
                multimodal_input = tllm.MultimodalInput(
                    multimodal_hashes=request.multimodal_params.
                    multimodal_input.multimodal_hashes,
                    multimodal_positions=request.multimodal_params.
                    multimodal_input.multimodal_positions,
                    multimodal_lengths=request.multimodal_params.
                    multimodal_input.multimodal_lengths)
            # NOTE: Setting to None here to avoid sending multimodal_input again through the 'py_multimodal_data' field
            request.multimodal_params.multimodal_input = None

        # Disaggregated serving: derive the request type and, for
        # generation-only requests, the context phase parameters.
        context_phase_params = None
        request_type = tllm.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION
        if request.disaggregated_params is not None:
            assert (
                not self._is_pytorch_backend
                or self.engine.kv_cache_transceiver is not None
                or request.disaggregated_params.request_type
                == "context_and_generation"
            ), "kv_cache_transceiver is disabled, please set 'cache_transceiver_config: backend:<backend_type>` in config file for disaggregated serving"
            request_type = request.disaggregated_params.get_request_type()
            if request_type == tllm.RequestType.REQUEST_TYPE_GENERATION_ONLY:
                context_phase_params = request.disaggregated_params.get_context_phase_params(
                )

        if self._is_pytorch_backend:
            if not self.llm_args.disable_overlap_scheduler:
                # An engine with a KV cache transceiver is treated as disaggregated.
                is_disaggregated = self.engine.kv_cache_transceiver is not None
                if is_disaggregated and (
                        request_type
                        == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
                    raise ValueError(
                        "Context only requests are not supported in pytorch backend when overlap is enabled."
                    )

        assert request.id is not None

        def _deduce_max_tokens(request: GenerationRequest,
                               executor_config: tllm.ExecutorConfig,
                               llm_args: Optional[BaseLlmArgs] = None) -> int:
            """Resolve the `max_tokens` value to put on the executor request.

            Reads `prompt_token_ids` from the enclosing scope. The deduced
            upper bound is `max_seq_len - prompt_len // cp_size - query_len`;
            a user value above that bound is replaced by the bound (with a
            warning), and a missing user value defaults to the bound.

            Raises:
                ValueError: `max_seq_len` is unknown and `max_tokens` is unset,
                    the deduced bound is not positive, or `max_tokens` is not
                    positive.
            """
            # deduce max_tokens when it's not set by user
            max_tokens = request.sampling_params.max_tokens
            query_token_len = len(
                request.query_token_ids) if request.query_token_ids else 0

            cp_size = 1
            max_seq_len = None
            if llm_args is not None:
                # deduce max_tokens by llm args
                assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined."
                if hasattr(self,
                           "mapping") and self.mapping.cp_size is not None:
                    cp_size = self.mapping.cp_size
                max_seq_len = getattr(self, "max_seq_len", None)
            else:
                # deduce max_tokens by executor config
                if hasattr(executor_config, "mapping"
                           ) and executor_config.mapping.cp_size is not None:
                    cp_size = executor_config.mapping.cp_size
                max_seq_len = getattr(executor_config, "max_seq_len", None)
            if max_seq_len is None:
                logger.warning("`default_max_tokens` cannot be deduced")
                if max_tokens is None:
                    raise ValueError(
                        "`max_tokens` must be set when `default_max_tokens` cannot be deduced"
                    )
                else:
                    # use max_tokens if can't deduce default_max_tokens
                    return max_tokens
            if executor_config is not None:
                assert (
                    len(prompt_token_ids) <= executor_config.max_seq_len
                ), f"`prompt_token_ids` length ({len(prompt_token_ids)}) is greater than `max_seq_len` ({executor_config.max_seq_len})"
            # Prompt length per context-parallel rank, truncated to an int.
            splited_prompt_len = int(len(prompt_token_ids) / cp_size)
            default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
            if default_max_tokens <= 0:
                # Raise error on `default_max_tokens` not enough, since max_tokens should be less than `default_max_tokens``
                raise ValueError(
                    f"`default_max_tokens` ({default_max_tokens}) must be greater than 0, "
                    f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({max_seq_len})"
                    f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})"
                )

            # default_max_tokens is the biggest available value
            if max_tokens is None:
                return default_max_tokens
            elif max_tokens > default_max_tokens and default_max_tokens > 0:
                logger.warning(
                    f"User-specified `max_tokens` ({max_tokens}) is greater than deduced "
                    f"`default_max_tokens` ({default_max_tokens}), using default_max_tokens instead."
                )
                return default_max_tokens
            elif max_tokens <= 0:
                raise ValueError(
                    f"`max_tokens` ({max_tokens}) must be greater than 0")
            else:
                return max_tokens

        # Everything from here on is reported to the caller as RequestError.
        try:
            executor_request = tllm.Request(
                client_id=request.id,
                input_token_ids=prompt_token_ids,
                # Exactly one of executor_config / llm_args is handed over.
                max_tokens=_deduce_max_tokens(
                    request,
                    self._executor_config if not self.llm_args else None,
                    self.llm_args),
                streaming=request.streaming,
                sampling_config=request.sampling_params._get_sampling_config(),
                # ignore_eos is expressed as an end_id of -1.
                end_id=-1 if request.sampling_params.ignore_eos else
                request.sampling_params.end_id,
                pad_id=request.sampling_params.pad_id,
                output_config=request.sampling_params._get_output_config(
                    is_pytorch_backend=self._is_pytorch_backend),
                # Beam search enforces return_all_generated_tokens=True regardless of the passed value
                return_all_generated_tokens=False,
                # convert python config into pybind config
                lookahead_config=PybindMirror.maybe_to_pybind(
                    request.sampling_params.lookahead_config),
                guided_decoding_params=request.sampling_params.
                _get_guided_decoding_params(),
                bad_words=request.sampling_params._get_bad_words(),
                stop_words=request.sampling_params._get_stop_words(),
                embedding_bias=request.sampling_params.embedding_bias,
                lora_config=lora_config,
                prompt_tuning_config=prompt_tuning_config,
                multimodal_input=multimodal_input,
                # NOTE: `multimodal_embedding` and `mrope_config` will be in MultimodalParams.multimodal_data. And this will be handled below by `py_multimodal_data`.
                multimodal_embedding=None,
                mrope_config=None,
                logits_post_processor_name=(
                    tllm.Request.BATCHED_POST_PROCESSOR_NAME
                    if request.sampling_params.apply_batched_logits_processor
                    else None),
                # On the pytorch backends the processor is attached below as a
                # Python attribute instead of going through the binding.
                logits_post_processor=None if self._is_pytorch_backend else
                request.sampling_params.logits_processor,
                kv_cache_retention_config=request.kv_cache_retention_config,
                context_phase_params=context_phase_params,
                type=request_type,
                cache_salt_id=request.cache_salt_id)
            # The `py_*` fields below are plain attributes set on the request object.
            executor_request.py_num_logprobs = request.sampling_params.logprobs
            executor_request.py_lora_path = py_lora_path

            if self._is_pytorch_backend and request.multimodal_params is not None:
                if request.multimodal_params.multimodal_data is not None:
                    # NOTE: Deserialize SharedTensor handle to actual tensor
                    request.multimodal_params.to_tensor("multimodal_data")
                    executor_request.py_multimodal_data = request.multimodal_params.multimodal_data

            if self._is_pytorch_backend and request.sampling_params.logits_processor:
                # For PyTorch backend, we attach logits processors as a dynamic Python attribute
                # instead of using the C++ binding, since the latter will cause PyCapsule pickling issues.
                lp = request.sampling_params.logits_processor
                executor_request.py_logits_post_processors = lp if isinstance(
                    lp, list) else [lp]

            executor_request.py_scheduling_params = None
            if self._is_pytorch_backend and request.scheduling_params is not None:
                executor_request.py_scheduling_params = request.scheduling_params

            if request.arrival_time is not None:
                executor_request.py_arrival_time = request.arrival_time

            # Four call shapes: with/without query tokens, and with/without
            # `result_wait_queue` (pytorch backends only).
            if request.query_token_ids is not None:
                # pytorch star attention workflow
                # a workaround to avoid public interface update
                if self._is_pytorch_backend and result_wait_queue is not None:
                    req_id = self.engine.enqueue_request(
                        executor_request,
                        request.query_token_ids,
                        result_wait_queue=result_wait_queue)
                else:
                    req_id = self.engine.enqueue_request(
                        executor_request, request.query_token_ids)
            else:
                if self._is_pytorch_backend and result_wait_queue is not None:
                    req_id = self.engine.enqueue_request(
                        executor_request, result_wait_queue=result_wait_queue)
                else:
                    req_id = self.engine.enqueue_request(executor_request)
            return req_id
        except Exception as e:
            # Keep the original exception as the cause for debugging.
            raise RequestError(str(e)) from e

    def submit(self, request: GenerationRequest) -> GenerationResult:
        """Low-level API to the executor.

        Starts the worker, registers a `GenerationResult` for the request and
        enqueues the request on the engine.

        Args:
            request: The request to submit. A client id is assigned to it when
                it has none.

        Returns:
            A "future" `GenerationResult` which can be waited.

        Raises:
            RuntimeError: Called on a rank other than 0.
            Exception: Whatever `_enqueue_request` raises is propagated after
                the bookkeeping entry for the request has been removed.
        """
        self.start()

        if self.rank != 0:
            raise RuntimeError(
                "Only rank 0 can submit requests.\n"
                "To fix this, ensure that the llm.generate(...) method is "
                "guarded with the `if __name__ == '__main__':` block.")

        client_id = request.id
        if client_id is None:
            client_id = self._get_next_client_id()
            request.set_id(client_id)

        logprob_params = self._get_logprob_params(request)

        result = GenerationResult(
            request,
            background_error_handler=self._handle_background_error,
            executor=self,
            disaggregated_params=request.disaggregated_params,
            logprob_params=logprob_params)

        # Register before enqueueing so the result is already in place when
        # the request reaches the engine.
        self._results[client_id] = result

        try:
            request_id = self._enqueue_request(request)
        except Exception:
            # The caller never receives `result` on failure, so drop the entry
            # instead of leaving it in `_results` forever.
            self._pop_result(client_id)
            raise
        # request_id returned from backend is necessary for the abort_request method.
        self._client_id_to_request_id[client_id] = request_id

        self._handle_background_error()

        return result

    def shutdown(self):
        if self.doing_shutdown:
            return
        else:
            self.doing_shutdown = True

        if self.engine is not None and self.engine.can_enqueue_requests():
            self.engine.shutdown()
            self.engine = None

    # Define a Callable to join iteration and request stats
    @staticmethod
    def _stats_serializer(
            stats: Tuple[tllm.IterationStats, tllm.RequestStats]) -> str:
        iteration_stats, req_stats = stats
        stats_dict = json.loads(iteration_stats.to_json_str())

        if req_stats is not None and len(req_stats) > 0:
            stats_dict["requestStats"] = []
            for req_stat in req_stats:
                stats_dict["requestStats"].append(
                    json.loads(req_stat.to_json_str()))

        # Convert back to JSON string
        return json.dumps(stats_dict)

    # Define a Callable to serialize KV cache events
    @staticmethod
    def _kv_cache_events_serializer(events) -> str:
        """Serialize KV cache events to a JSON string via `KVCacheEventSerializer`."""
        # Imported at call time rather than at module import.
        from .._utils import KVCacheEventSerializer as _EventSerializer
        serializable = _EventSerializer.serialize(events)
        return json.dumps(serializable)

    def _pop_result(self, client_id: int):
        self._results.pop(client_id, None)
        self._client_id_to_request_id.pop(client_id, None)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        self.shutdown()
        return exc_type is None or exc_type == self.WorkerExit

    def __del__(self):
        self.shutdown()


class AwaitResponseHelper:
    ''' Multiple-implementations for await_response for performance. '''

    class HandlerKind(enum.Enum):
        unknown = 0
        single_process_worker = 1
        ipc_batched = 2

    def __init__(self, worker: "BaseWorker"):
        # TODO: make worker weakref
        self.worker = worker
        self.handler_kind: AwaitResponseHelper.HandlerKind = AwaitResponseHelper.HandlerKind.unknown
        self.enable_postprocprocess_parallel = self.worker.enable_postprocess_parallel
        # The error responses when submit request failed will be put here
        self.temp_error_responses = Queue()

    def responses_handler(self, responses: List[tllm.Response]):
        HandlerKind = AwaitResponseHelper.HandlerKind

        if self.handler_kind is HandlerKind.unknown:
            if not (self.worker.result_queue is not None
                    or self.worker.postproc_queues is not None):
                logger_debug(f"creating await_response helper for Worker\n",
                             color="yellow")
                # When ExecutorBindingWorker is used in the main process
                # aka the single process mode
                self.handler_kind = HandlerKind.single_process_worker
            elif self.worker.result_queue is not None or self.worker.postproc_queues is not None:
                # The ExecutorBindingProxy is used
                logger_debug(f"creating await_response helper for IPC\n",
                             color="yellow")
                self.handler_kind = HandlerKind.ipc_batched
            else:
                raise NotImplementedError

        match self.handler_kind:
            case HandlerKind.single_process_worker:
                return self.handle_for_worker(responses)
            case HandlerKind.ipc_batched:
                return self.handle_for_ipc_batched(responses)
            case _:
                raise NotImplementedError

    def __call__(self, timeout: Optional[float] = None) -> bool:
        ''' This method should be called by a ManagedThread. '''
        timeout = timeout or 0.1
        responses = self.worker.engine.await_responses(
            timeout=datetime.timedelta(seconds=timeout))
        # filter since The _engine_response_callback may return None
        responses = list(
            filter(
                lambda _: _,
                [self.worker._engine_response_callback(r) for r in responses]))

        # append the error responses to the temp_error_responses
        while not self.temp_error_responses.empty():
            responses.append(self.temp_error_responses.get())

        with nvtx_range_debug(f"await_response-{len(responses)}",
                              color="red",
                              category="Worker"):
            self.responses_handler(responses)
        return True

    def handle_for_worker(self, responses: List[tllm.Response]) -> None:
        ''' Return the responses to asyncio.event_loop. '''
        event_loop = None
        async_queues = []
        for response in responses:
            assert response is not None
            queue = self.worker.return_queue(response.client_id)

            if not response.has_error():
                response = _maybe_wrap_response(self.worker, response,
                                                self.worker._is_pytorch_backend)

            # For AsyncQueue.sync_q, we will batch the events to avoid too many
            # event notifications, thus put without wait here.
            if isinstance(queue, _SyncQueue):
                global_tracer().log_instant("worker-rsp.put")
                queue.put_nowait(response)
                async_queues.append(queue)
                # all the loops are identical
                event_loop = event_loop or queue.loop
            else:
                queue.put(response)

            if response.has_error() or response.result.is_final:
                self.worker._pop_result(response.client_id)

        # Notify the events in bulk for performance.
        if async_queues:
            _SyncQueue.notify_many(event_loop, async_queues)

    def handle_for_ipc_batched(self, responses: List[tllm.Response]) -> None:
        ''' Perform the IPC in batch explicitly. '''
        postproc_batches = [
            []
            for _ in range(self.worker.postproc_config.num_postprocess_workers)
        ] if self.enable_postprocprocess_parallel else None
        rsp_batch = [] if not self.enable_postprocprocess_parallel else None

        for response in responses:

            if isinstance(response, ErrorResponse):
                pass  # send ErrorResponse directly
            elif self.worker._has_background_error():
                response = self.worker._create_error_response(response)
            elif response.has_error():
                # Convert to ErrorResponse, because tllm.Response cannot be
                # serialized when it has error.
                response = ErrorResponse(response.client_id, response.error_msg,
                                         response.request_id)
            else:
                response = _maybe_wrap_response(self.worker, response,
                                                self.worker._is_pytorch_backend)

            _send_rsp(self.worker,
                      response,
                      postproc_batches=postproc_batches,
                      rsp_batch=rsp_batch)

        if postproc_batches:
            for wid, batch in enumerate(postproc_batches):
                if batch:
                    self.worker.postproc_queues[wid].put(batch)

        if rsp_batch:
            self.worker.result_queue.put(rsp_batch)


def _get_params_for_first_rsp(
        worker,
        client_id) -> Tuple[Optional[SamplingParams], Optional[PostprocParams]]:
    """Return the request's params the first time it is asked for, else Nones.

    The result registered under `client_id` in `worker._results` must exist.
    On the first call for that result its `_params_transmitted` flag is set
    and `(sampling_params, postproc_params)` is returned; every later call
    returns `(None, None)`.
    """
    pending = worker._results.get(client_id, None)
    assert pending is not None

    if pending._params_transmitted:
        return None, None

    pending._params_transmitted = True
    return pending.sampling_params, pending.postproc_params


def _compute_pytorch_prompt_logprobs(
        generation_result: GenerationResult,
        response: LlmResponse) -> Optional[LogProbsResult]:
    """Compute prompt logprobs for PyTorch backend (cached when streaming).

    For a streaming request the prompt logprobs are stored on
    `generation_result._cached_prompt_logprobs` after the first computation
    and later calls return a result built from that cache. Otherwise they are
    computed from `response.result.context_logits`, which must not be None.
    """
    logprob_params = generation_result._logprob_params  # should be present and non None
    assert logprob_params is not None

    is_streaming = generation_result._streaming
    if is_streaming:
        cached_prompt = getattr(generation_result, '_cached_prompt_logprobs',
                                None)
        if cached_prompt is not None:
            # generation logprobs, if requested, is provided directly in
            # response.result.log_probs from the sampler.
            return LogProbsResult(prompt=cached_prompt, generation=None)

    context_logits = response.result.context_logits
    assert context_logits is not None, "context_logits cannot be None when prompt_logprobs is requested."

    computed = compute_logprobs(logprob_params.prompt_logprobs, None,
                                context_logits, None, None)
    if is_streaming:
        generation_result._cached_prompt_logprobs = computed.prompt

    return computed


def _get_logprobs(worker,
                  response: Union[tllm.Response, LlmResponse],
                  is_pytorch_backend=False) -> Optional[LogProbsResult]:
    """Compute logprobs from response logits when needed.

    Logprobs provenance varies by backend:
    - PyTorch: Generation logprobs computed in sampler, only prompt logprobs computed here
    - TRT: Both prompt and generation logprobs computed here from logits

    Returns None when no result is registered for the response's client id,
    when the result carries no logprob params, or on the PyTorch backend when
    prompt logprobs were not requested.
    """
    generation_result = worker._results.get(response.client_id, None)
    if not generation_result:
        return None

    logprobs_result = None
    logprob_params = getattr(generation_result, "_logprob_params", None)

    if logprob_params and is_pytorch_backend:
        if not logprob_params.prompt_logprobs:
            # PyTorch: generation logprobs computed in sampler, no post-processing needed
            return None

        prompt_only = _compute_pytorch_prompt_logprobs(generation_result,
                                                       response)
        if logprob_params.drop_context_logits:
            response.clear_context_logits()
        return prompt_only

    if logprob_params:
        # TRT backend: compute both prompt and generation logprobs from logits
        logprobs_result = compute_logprobs(
            logprob_params.prompt_logprobs, logprob_params.logprobs,
            response.result.context_logits, response.result.generation_logits,
            response.result.output_token_ids[0])

        # Drop the logits from the response when the params ask for it.
        if logprob_params.drop_context_logits:
            response.clear_context_logits()
        if logprob_params.drop_generation_logits:
            response.clear_generation_logits()

    if response.result.is_final:
        generation_result.clear_logprob_params()

    return logprobs_result


def _send_rsp(
        worker,
        response: Union[tllm.Response, ResponseWrapper, ErrorResponse],
        postproc_batches: Optional[List[List["PostprocWorker.Input"]]] = None,
        rsp_batch: Optional[List[tllm.Response]] = None):
    """Send (or batch) one response, then drop finished-request bookkeeping.

    When the worker has a result queue, the response is appended to
    `rsp_batch` if one is given, else put on the queue directly. Without a
    result queue the response is wrapped in a `PostprocWorker.Input` and
    either appended to the matching `postproc_batches` bucket or put on the
    matching postprocess queue.

    Raises:
        ValueError: if the response is neither an LLM response nor an
            `ErrorResponse`. This check runs after the response was sent.
    """
    # if postproc_batches is set, append to batch instead of putting to IpcQueue

    if worker.result_queue is None:
        client_id = response.client_id
        sampling_params, postproc_params = _get_params_for_first_rsp(
            worker, client_id)
        streaming = worker._results.get(client_id, None)._streaming
        inp = PostprocWorker.Input(
            response,
            # sampling_params is necessary for creating fake GenerationResult
            # instances in the postproc processes. They are for incremental
            # detokenize. They should be transmitted only once for each
            # Request.
            sampling_params=sampling_params,
            postproc_params=postproc_params,
            streaming=streaming)

        # Group the responses into buckets for the postprocessing steps.
        # Bucketing is used instead of random dispatching because the
        # incremental detokenization during postprocessing relies on the
        # prior CompletionOutput of a given request.
        pid = client_id % worker.postproc_config.num_postprocess_workers

        if postproc_batches:
            postproc_batches[pid].append(inp)
        else:
            worker.postproc_queues[pid].put(inp)
    elif rsp_batch is None:
        worker.result_queue.put(response)
    else:
        rsp_batch.append(response)

    # Eliminate the finished GenerationRequest instances timely, which may
    # take considerable memory.
    if is_llm_response(response):
        finished = response.has_error() or response.result.is_final
    elif isinstance(response, ErrorResponse):
        finished = True
    else:
        raise ValueError(f"Unknown response type: {response}")

    if finished:
        worker._pop_result(response.client_id)


def _get_metrics_dict(
        response: tllm.Response) -> dict[RequestEventTiming, float]:
    """Extract the request timing metrics of `response` as a dict.

    The perf metrics are read from `response.result`; when that object has a
    `_result` attribute they are read from `get_result()` instead. Time values
    are converted with `total_seconds()`; the KV cache size is passed through
    unchanged.

    Returns:
        A dict keyed by `RequestEventTiming`, or an empty dict when the
        response has no result, no perf metrics or no timing metrics.
    """
    res = response.result
    if not res:
        return {}

    if hasattr(res, '_result'):
        inner = res.get_result()
        req_perf_metrics = inner.request_perf_metrics if inner else None
    else:
        req_perf_metrics = res.request_perf_metrics

    if not (req_perf_metrics and req_perf_metrics.timing_metrics):
        return {}

    timing = req_perf_metrics.timing_metrics
    return {
        RequestEventTiming.ARRIVAL_TIME:
        timing.arrival_time.total_seconds(),
        RequestEventTiming.FIRST_TOKEN_TIME:
        timing.first_token_time.total_seconds(),
        RequestEventTiming.FIRST_SCHEDULED_TIME:
        timing.first_scheduled_time.total_seconds(),
        RequestEventTiming.LAST_TOKEN_TIME:
        timing.last_token_time.total_seconds(),
        RequestEventTiming.KV_CACHE_TRANSFER_START:
        timing.kv_cache_transfer_start.total_seconds(),
        RequestEventTiming.KV_CACHE_TRANSFER_END:
        timing.kv_cache_transfer_end.total_seconds(),
        RequestEventTiming.KV_CACHE_SIZE:
        timing.kv_cache_size,
    }


def _maybe_wrap_response(
        worker,
        response: tllm.Response,
        is_pytorch_backend=False) -> Union[tllm.Response, ResponseWrapper]:
    """Attach logprobs and timing metrics to `response` when there are any.

    Returns:
        A `ResponseWrapper` around `response` when either the logprobs result
        or the metrics dict is truthy; otherwise `response` itself.
    """
    logprobs = _get_logprobs(worker, response, is_pytorch_backend)
    metrics = _get_metrics_dict(response)

    if not (logprobs or metrics):
        return response
    return ResponseWrapper(response, logprobs, metrics)
